"""This module implements the ExpanderGraphLifting class."""

import warnings

import networkx
import torch
import torch_geometric

from topobench.transforms.liftings.graph2hypergraph.base import (
    Graph2HypergraphLifting,
)


class ExpanderGraphLifting(Graph2HypergraphLifting):
    r"""Lift graphs to expander (hyper)graphs.

    The input edges are not used. Instead, a random ``node_degree``-regular
    graph is sampled on the same node set; with high probability it is an
    expander (a random Ramanujan-type graph, see ``maybe_regular_expander``).
    Every edge of that graph becomes a hyperedge containing its two endpoints.

    Parameters
    ----------
    node_degree : int
        The desired node degree of the expander graph. Must be even. When the
        topology is lifted it must also be at least 2 and at most
        ``num_nodes - 1``; otherwise graph generation raises.
    **kwargs : optional
        Additional arguments for the class, forwarded to
        ``Graph2HypergraphLifting``.

    Raises
    ------
    AssertionError
        If ``node_degree`` is odd.
    """

    def __init__(self, node_degree: int, **kwargs):
        super().__init__(**kwargs)

        # The generator builds the graph from node_degree / 2 independent
        # cycles, so only even degrees can be realised.
        assert node_degree % 2 == 0, "Only even node degree is supported."

        self.node_degree = node_degree

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Lift the topology of a graph to an expander hypergraph.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data to be lifted. Only ``data.num_nodes`` and
            ``data.x`` are used.

        Returns
        -------
        dict
            The lifted topology:

            - ``"incidence_hyperedges"``: sparse COO ``float32`` tensor of
              shape ``(num_nodes, num_hyperedges)``.
            - ``"num_hyperedges"``: number of hyperedges (edges of the
              sampled expander graph).
            - ``"x_0"``: the node features ``data.x``, passed through.

        Raises
        ------
        networkx.NetworkXError
            If no ``node_degree``-regular graph can be generated for
            ``data.num_nodes`` nodes.
        """
        expander_graph = maybe_regular_expander(
            data.num_nodes, self.node_degree
        )

        # Older networkx versions emit a FutureWarning about the return type
        # of incidence_matrix; it is irrelevant because we convert to COO
        # right away.
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=FutureWarning)

            incidence_matrix = networkx.incidence_matrix(
                expander_graph
            ).tocoo()

        coo_indices = torch.stack(
            (
                torch.from_numpy(incidence_matrix.row),
                torch.from_numpy(incidence_matrix.col),
            )
        )
        # 4-byte floating point numbers (single precision).
        coo_values = torch.from_numpy(incidence_matrix.data.astype("f4"))

        # Pass the shape explicitly instead of letting torch infer it from the
        # largest indices, so the tensor always matches the scipy matrix.
        incidence_matrix = torch.sparse_coo_tensor(
            coo_indices, coo_values, size=incidence_matrix.shape
        )

        return {
            "incidence_hyperedges": incidence_matrix,
            "num_hyperedges": incidence_matrix.size(1),
            "x_0": data.x,
        }


"""
Random regular expander graphs are only available in networkx >= 3.3, which currently
conflicts with other dependencies. We therefore vendor the networkx implementation below.
After upgrading to networkx >= 3.3 this fallback should be removed; the upgrade should
also get rid of the FutureWarnings.
"""

# Prefer the upstream implementation; fall back to the vendored copy when the
# installed networkx does not provide it. Trying the import directly (instead
# of checking for a different name in ``__all__``) guarantees we only take the
# upstream path when the name we actually use exists.
try:
    from networkx.generators.expanders import maybe_regular_expander

except ImportError:
    nx = networkx

    # Upstream additionally decorates this function with
    # ``@nx._dispatchable(graphs=None, returns_graph=True)``, which does not
    # exist in the networkx versions this fallback targets.
    @nx.utils.decorators.np_random_state("seed")
    def maybe_regular_expander(
        n, d, *, create_using=None, max_tries=100, seed=None
    ):
        r"""Utility for creating a random regular expander.

        Returns a random $d$-regular graph on $n$ nodes which is an expander
        graph with very good probability.

        Parameters
        ----------
        n : int
          The number of nodes.
        d : int
          The degree of each node. Must be even and at least 2.
        create_using : Graph Instance or Constructor
          Indicator of type of graph to return.
          If a Graph-type instance, then clear and use it.
          If a constructor, call it to create an empty graph.
          Use the Graph constructor by default.
        max_tries : int, default 100
          Maximum number of random cycles drawn while searching for each new
          cycle that is edge-disjoint from the previous ones. At least one
          attempt is always made.
        seed : (default: None)
          Seed used to set random number generation state. See
          :ref:`Randomness<randomness>`.

        Returns
        -------
        G : graph
            The constructed undirected graph.

        Raises
        ------
        NetworkXError
            If $n < 1$.
            If $d < 2$ or $d$ is odd.
            If $n - 1 < d$, since a simple $d$-regular graph needs at least
            $d + 1$ nodes.
            If ``max_tries`` attempts fail to produce an edge-disjoint cycle.

        See Also
        --------
        is_regular_expander
        random_regular_expander_graph

        Notes
        -----
        The nodes are numbered from $0$ to $n - 1$.

        The graph is the union of $d / 2$ random, pairwise edge-disjoint
        Hamiltonian cycles, so every node has degree exactly $d$.

        Joel Friedman proved that in this model the resulting
        graph is an expander with probability
        $1 - O(n^{-\tau})$ where $\tau = \lceil (\sqrt{d - 1}) / 2 \rceil - 1$. [1]_

        References
        ----------
        .. [1] Joel Friedman,
           A Proof of Alon's Second Eigenvalue Conjecture and Related Problems, 2004
           https://arxiv.org/abs/cs/0405020

        Examples
        --------
        >>> G = maybe_regular_expander(n=200, d=6, seed=8020)
        >>> all(deg == 6 for _, deg in G.degree())
        True
        """
        if n < 1:
            raise nx.NetworkXError("n must be a positive integer")

        if not (d >= 2):
            raise nx.NetworkXError("d must be greater than or equal to 2")

        if not (d % 2 == 0):
            raise nx.NetworkXError("d must be even")

        if not (n - 1 >= d):
            raise nx.NetworkXError(
                f"Need n-1>= d to have room for {d // 2} independent cycles with {n} nodes"
            )

        # The checks above guarantee n >= d + 1 >= 3, so every cycle below
        # has n distinct edges.
        G = nx.empty_graph(n, create_using)

        edges = set()

        # Create d / 2 pairwise edge-disjoint Hamiltonian cycles.
        for _ in range(d // 2):
            attempts = 0
            while True:
                attempts += 1
                # Faster than random.permutation(n) since there are only
                # (n-1)! distinct cycles against n! permutations of size n.
                cycle = seed.permutation(n - 1).tolist()
                cycle.append(n - 1)

                new_edges = {
                    (u, v)
                    for u, v in nx.utils.pairwise(cycle, cyclic=True)
                    if (u, v) not in edges and (v, u) not in edges
                }
                # Accept the cycle only if it shares no edge with previous
                # cycles; this keeps the final graph exactly d-regular.
                if len(new_edges) == n:
                    edges.update(new_edges)
                    break

                # Checked only after a failed attempt, so a cycle found on
                # the last allowed try is still accepted.
                if attempts >= max_tries:
                    raise nx.NetworkXError(
                        "Too many iterations in maybe_regular_expander"
                    )

        G.add_edges_from(edges)

        return G
